# Functions for running mixture IRT models stripped from mixture_irt.Rmd

library(aroma.light)
library(gtools)
library(Rcpp)
library(ggplot2)
library(tidyverse)
library(reldist)
library(ggtern)

## Rcpp-based function

# irt_marg_lik(y, x, a, b): Monte Carlo estimate of the log marginal likelihood
# of each row of `y` under a logistic item response model, compiled via Rcpp.
#
# Arguments (as seen from R):
#   y  numeric matrix, one row per respondent and one column per item. An entry
#      equal to 1 contributes log(phat); any other non-missing entry
#      contributes log(1 - phat); NA entries are skipped.
#   x  numeric vector of latent-trait draws to average over.
#   a  numeric vector of item intercepts, indexed by column of `y`.
#   b  numeric vector of item slopes, indexed by column of `y`.
#
# Returns a numeric vector with one entry per row k of `y`:
#   log( (1 / length(x)) * sum_i exp( sum_j loglik(y[k, j] | x[i]) ) )
# where phat = 1 / (1 + exp(-a[j] - b[j] * x[i])).
#
# The average is accumulated as a running log-sum-exp (a running maximum plus
# a rescaled sum) so that rows whose per-draw likelihoods are too small for
# exp() to represent still yield a finite log marginal likelihood instead of
# log(0) = -Inf.
cppFunction("
    NumericVector irt_marg_lik(NumericMatrix y, NumericVector x,
                           NumericVector a, NumericVector b) {
        const int n_row = y.nrow();
        const int n_col = y.ncol();
        const int n_draw = static_cast<int>(x.length());

        NumericVector mlik(n_row);
        NumericVector row_log_lik(n_row);
        NumericVector run_max(n_row);  // largest per-draw log-likelihood so far
        NumericVector run_sum(n_row);  // sum over draws of exp(loglik - run_max)

        for (int k = 0; k < n_row; k++) {
          run_max(k) = R_NegInf;
          run_sum(k) = 0.0;
        }

        for (int i = 0; i < n_draw; i++) {
          for (int k = 0; k < n_row; k++) row_log_lik(k) = 0.0;

          for (int j = 0; j < n_col; j++) {
            // The success probability depends only on (item, draw), so both
            // logs are computed once per item rather than once per row.
            const double phat = 1.0 / (1.0 + exp(-a[j] - b[j] * x[i]));
            const double log_p = log(phat);
            const double log_q = log(1.0 - phat);
            for (int k = 0; k < n_row; k++) {
              if (!NumericMatrix::is_na(y(k, j)))
                row_log_lik(k) += y(k, j) == 1 ? log_p : log_q;
            }
          }

          for (int k = 0; k < n_row; k++) {
            const double ll = row_log_lik(k);
            // A draw with zero likelihood adds nothing to the sum; skipping it
            // also avoids computing (-Inf) - (-Inf) below.
            if (ll == R_NegInf) continue;
            if (ll > run_max(k)) {
              // New maximum: rescale what has been accumulated so far.
              run_sum(k) = run_sum(k) * exp(run_max(k) - ll) + 1.0;
              run_max(k) = ll;
            } else {
              run_sum(k) += exp(ll - run_max(k));
            }
          }
        }

        const double log_n_draw = log(static_cast<double>(n_draw));
        for (int k = 0; k < n_row; k++) {
          mlik(k) = run_max(k) + log(run_sum(k)) - log_n_draw;
        }
        return mlik;
}
")


## R functions

# Log-likelihood of each row of `y` when every column has success
# probability 0.5.
#
# Args:
#   y: matrix-like object processed row by row with apply(); NA terms are
#      dropped from each row's sum.
#
# Returns:
#   A list with a single element `lk`, the per-row log-likelihoods.
flip_vote_loglik <- function(y) {
  half <- rep(0.5, NCOL(y))

  # Bernoulli log-likelihood of one row at probability 0.5 for every column.
  row_loglik <- function(votes) {
    terms <- vlog(votes, half) + vlog(1 - votes, 1 - half)
    sum(terms, na.rm = TRUE)
  }

  list(lk = apply(y, 1, row_loglik))
}
# Simulate independent binary responses with column-specific probabilities.
#
# Args:
#   n: number of rows to simulate.
#   k: number of columns to simulate.
#   p: success probability for each column; by default `k` draws from
#      Beta(5, 5). The default is evaluated lazily, i.e. after the uniform
#      draws below, so the random-number stream is consumed as runif then
#      rbeta.
#
# Returns:
#   A list with `y`, an n x k numeric 0/1 matrix where y[i, j] is 1 when
#   p[j] exceeds the uniform draw for that cell, and `p`, the probabilities
#   used.
ind_vote_data <- function(n=100, k=10, p = rbeta(k, 5, 5)) {
  z <- matrix(runif(n * k), nrow = n, ncol = k)

  # Compare every row of `z` against `p` in one vectorised step. Unlike
  # t(apply(z, 1, ...)), this keeps the n x k shape when k == 1.
  p_by_row <- matrix(p, nrow = n, ncol = k, byrow = TRUE)
  y <- (p_by_row > z) + 0

  # Carry names of `p` over as column names, as a per-row comparison would.
  if (!is.null(names(p)) && length(p) == k) {
    colnames(y) <- names(p)
  }

  list(y = y, p = p)
}
# Estimate column-wise success probabilities as weighted means and score each
# row of `y` against them.
#
# Args:
#   y: matrix-like object of responses; NA entries are ignored in both the
#      weighted means and the row sums.
#   w: weights passed to weighted.mean() for every column of `y`.
#
# Returns:
#   A list with `ivp`, the weighted mean of each column, and `lk`, the
#   per-row sum of vlog(y, ivp) + vlog(1 - y, 1 - ivp).
ind_vote_est <- function(y, w) {
  col_prob <- apply(y, 2, function(votes) {
    weighted.mean(votes, w = w, na.rm = TRUE)
  })

  # Bernoulli log-likelihood of one row at the estimated column probabilities.
  row_loglik <- function(votes) {
    terms <- vlog(votes, col_prob) + vlog(1 - votes, 1 - col_prob)
    sum(terms, na.rm = TRUE)
  }

  list(ivp = col_prob, lk = apply(y, 1, row_loglik))
}
# Simulate binary response data from a logistic model with `ndim` latent
# dimensions.
#
# Args:
#   n:    number of rows (standard-normal latent positions are drawn per row).
#   k:    number of columns (intercepts and slopes are drawn per column).
#   ndim: number of latent dimensions.
#
# Returns:
#   A list with `y` (n x k 0/1 matrix), `phat` (n x k success probabilities),
#   `x` (n x ndim latent positions), `b` (ndim x k slopes, Uniform(-3, 3))
#   and `a` (length-k intercepts, Uniform(-3, 3)).
make_dat <- function(n=2000, k=10, ndim=3) {
  # The draws below are made in a fixed order (positions, slopes, intercepts,
  # response uniforms) so results are reproducible under a given seed.
  positions <- matrix(rnorm(n * ndim), nrow = n, ncol = ndim)
  slopes <- matrix(runif(ndim * k, min = -3, max = 3), nrow = ndim, ncol = k)
  intercepts <- runif(k, min = -3, max = 3)

  linear_pred <- cbind(1, positions) %*% rbind(intercepts, slopes)
  prob <- plogis(linear_pred)

  uniforms <- matrix(runif(n * k), nrow = n, ncol = k)
  responses <- ifelse(uniforms < prob, 1, 0)

  list(y = responses, phat = prob, x = positions, b = slopes, a = intercepts)
}
# Fit a logistic item response model by repeatedly applying weighted PCA
# (wpca) to a working-response matrix.
#
# Each iteration decomposes the working response `z` with wpca(), takes the
# column means as intercepts, the first `ndim` principal components as row
# scores and the first `ndim` rows of `vt` as slopes, then rebuilds `z` as the
# linear predictor plus 4 * (y - fitted probability). Missing entries of `y`
# are imputed with the current linear predictor.
#
# Args:
#   y:      response matrix; NA marks a missing response.
#   w:      optional row weights, passed to wpca() and used to weight the
#           log-likelihood; NULL means unweighted.
#   iters:  number of iterations to run (at least 1).
#   ndim:   number of latent dimensions to keep.
#   starts: optional list with elements `x`, `a` and `b` (such as a previous
#           return value of this function) used to initialise the working
#           response; NULL starts from 4 * (y - 0.5).
#
# Returns:
#   A list with `x` (row scores), `a` (intercepts), `b` (slopes), `lk`
#   (per-row log-likelihood, weighted when `w` is given), `lt` (total
#   log-likelihood over observed cells) and `svd` (first `ndim` elements of
#   wpca's `d`), all taken from the last iteration.
#
# Side effects:
#   Prints the log-likelihood and the scaled change in `x` on the first
#   iteration, every tenth iteration and the last iteration.
qm_irt <- function(y, w=NULL, iters=10, ndim=3, starts=NULL) {
  # 1:iters would run iterations 1 and 0 for iters = 0, so reject it up front.
  if (length(iters) != 1 || is.na(iters) || iters < 1) {
    stop("`iters` must be a single number >= 1", call. = FALSE)
  }

  missing_y <- is.na(y)

  if (is.null(starts)) {
    z <- ifelse(missing_y, 0, 4 * (y - 0.5))
    # Inf makes the first reported change in x infinite (no previous value).
    x <- rep(Inf, NROW(z))
  } else {
    x <- starts$x
    aa <- cbind(1, starts$x) %*% rbind(starts$a, starts$b)
    z <- ifelse(missing_y, aa, aa + 4 * (y - plogis(aa)))
  }

  dims <- seq_len(ndim)
  for (i in seq_len(iters)) {
    x_prev <- x
    res <- wpca(z, w=w, center=TRUE)
    a <- res$xMean
    x <- res$pc[, dims]
    b <- res$vt[dims, ]
    aa <- cbind(1, x) %*% rbind(a, b)
    pr <- plogis(aa)

    # Cell-wise Bernoulli log-likelihood; `w` recycles down columns, so row i
    # is scaled by w[i].
    lk <- vlog(y, pr) + vlog(1 - y, 1 - pr)
    if (!is.null(w)) {
      lk <- w * lk
    }
    lt <- sum(lk[!missing_y])

    z <- ifelse(missing_y, aa, aa + 4 * (y - pr))
    rmsd <- sqrt(sum((x - x_prev)^2)) / sd(x)
    if (i == 1 || i %% 10 == 0 || i == iters) {
      cat(sprintf("Iteration: %i (%9.4f, %6.4f)\n", i, lt, rmsd))
    }
  }

  list(x = x, a = a, b = b,
       lk = rowSums(lk, na.rm = TRUE),
       lt = lt,
       svd = res$d[dims])
}
# Elementwise a * log(b), with 0 wherever b is not strictly positive.
#
# Guards likelihood terms such as 0 * log(0), which would otherwise be NaN.
# The result takes its shape from `b`; an NA in `b` yields NA.
vlog <- function(a, b) {
  is_positive <- b > 0
  ifelse(is_positive, a * log(b), 0.0)
}

# EM algorithm for a three-component mixture over the rows of `y`:
#   component 1: one-dimensional IRT model fitted by qm_irt(),
#   component 2: independent column probabilities fitted by ind_vote_est(),
#   component 3: fixed probability 0.5 for every column (flip_vote_loglik()).
#
# Args:
#   y:         response matrix (rows are the units being classified).
#   iters:     number of outer EM iterations (at least 1).
#   irt_iters: number of qm_irt() iterations per outer iteration.
#   nsamp:     number of latent-position draws used by irt_marg_lik().
#
# Returns:
#   A list with `w` (n x 3 matrix of posterior component memberships), `ivp`
#   (last ind_vote_est() result) and `irt` (last qm_irt() result).
#
# Side effects:
#   Uses the random number generator (initial memberships and the draws of
#   latent positions) and prints progress with cat().
em_mix_irt <- function(y, iters=40, irt_iters=40, nsamp=1000) {
  # 1:iters would run iterations 1 and 0 for iters = 0, so reject it up front.
  if (length(iters) != 1 || is.na(iters) || iters < 1) {
    stop("`iters` must be a single number >= 1", call. = FALSE)
  }

  r <- NULL
  w <- rdirichlet(NROW(y), alpha=c(5,1,1))
  p <- colMeans(w)

  for (i in seq_len(iters)) {
    cat(sprintf("Outer iteration: %i\n", i))

    # M-step: refit each component using the current memberships as weights.
    # The previous IRT fit (NULL on the first pass) warm-starts qm_irt().
    r <- qm_irt(y, w=w[,1], ndim=1, iters=irt_iters, starts=r)
    ivp <- ind_vote_est(y, w=w[,2])
    fvp <- flip_vote_loglik(y)

    # The IRT component is scored by its marginal likelihood, averaged over
    # fitted positions resampled in proportion to IRT membership, rather than
    # by the plug-in likelihood r$lk at each row's own fitted position.
    xsamp <- sample(r$x, size=nsamp, prob=w[,1], replace=TRUE)
    irt_mlk <- irt_marg_lik(y=y, x=xsamp, a=r$a, b=r$b)

    # E-step in log space: subtracting the row-wise maximum before
    # exponentiating keeps large log-likelihood gaps from overflowing to Inf
    # (and the resulting Inf / Inf = NaN memberships).
    lj_irt <- log(p[1]) + irt_mlk
    lj_ind <- log(p[2]) + ivp$lk
    lj_flip <- log(p[3]) + fvp$lk
    lj_max <- pmax(lj_irt, lj_ind, lj_flip)
    u_irt <- exp(lj_irt - lj_max)
    u_ind <- exp(lj_ind - lj_max)
    u_flip <- exp(lj_flip - lj_max)
    u_total <- u_irt + u_ind + u_flip

    w[,1] <- u_irt / u_total
    w[,2] <- u_ind / u_total
    w[,3] <- u_flip / u_total

    p <- colMeans(w)
    cat(sprintf("Group probs = (%6.4f, %6.4f, %6.4f)\n\n",  p[1], p[2], p[3]))
  }

  list(w = w, ivp = ivp, irt = r)
}
